{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "E2Cu87RMWw-P"
      },
      "source": [
        "### 1. Install and import the required packages"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "4Px8aik4VaOY"
      },
      "outputs": [],
      "source": [
        "!pip install transformers sentence-transformers datasets"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RUsTXFi1bNRI"
      },
      "outputs": [],
      "source": [
        "from datasets import load_dataset\n",
        "from sentence_transformers import SentenceTransformer, models\n",
        "from transformers import BertTokenizer\n",
        "from transformers import get_linear_schedule_with_warmup\n",
        "import torch\n",
        "from torch.optim import AdamW\n",
        "from torch.utils.data import DataLoader\n",
        "from tqdm import tqdm\n",
        "import time\n",
        "import datetime\n",
        "import random\n",
        "import numpy as np\n",
        "import pandas as pd"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "zMdAdDQbzWmC"
      },
      "source": [
        "### 2. Use Google Colab's GPU for training"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "wB7TNNSrziMu",
        "outputId": "53715022-a7af-439f-f978-637799295f85"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "There are 1 GPU(s) available.\n",
            "We will use the GPU: Tesla T4\n"
          ]
        }
      ],
      "source": [
        "# Pick the compute device once: CUDA when PyTorch can see a GPU, otherwise the CPU.\n",
        "device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
        "\n",
        "if device.type == \"cuda\":\n",
        "    print(f'There are {torch.cuda.device_count()} GPU(s) available.')\n",
        "    print('We will use the GPU:', torch.cuda.get_device_name(0))\n",
        "else:\n",
        "    print('No GPU available, using the CPU instead.')"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "kQ1Eel-3W-5b"
      },
      "source": [
        "### 3. Load and preview the Semantic Textual Similarity Benchmark (STSB) dataset"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "mgwlDDjtWM71"
      },
      "outputs": [],
      "source": [
        "# Load the English (\"en\") configuration of the STSB dataset.\n",
        "# The result is a DatasetDict; its splits are inspected in the next cell.\n",
        "dataset = load_dataset(\"stsb_multi_mt\", \"en\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "BtUWgi0h_DjR",
        "outputId": "bcd36c5b-7a37-4c8c-8bb5-8a46e7ed4d5c"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "DatasetDict({\n",
            "    train: Dataset({\n",
            "        features: ['sentence1', 'sentence2', 'similarity_score'],\n",
            "        num_rows: 5749\n",
            "    })\n",
            "    test: Dataset({\n",
            "        features: ['sentence1', 'sentence2', 'similarity_score'],\n",
            "        num_rows: 1379\n",
            "    })\n",
            "    dev: Dataset({\n",
            "        features: ['sentence1', 'sentence2', 'similarity_score'],\n",
            "        num_rows: 1500\n",
            "    })\n",
            "})\n"
          ]
        }
      ],
      "source": [
        "# Summarize the loaded splits: their names, columns and row counts.\n",
        "print(dataset)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "FEHZl4WeWv6r",
        "outputId": "69885fad-1282-48e8-ab5e-29da8c548a85"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "A sample from the STSB dataset's training split:\n",
            "{'sentence1': 'A man is slicing potatoes.', 'sentence2': 'A woman is peeling potato.', 'similarity_score': 2.200000047683716}\n"
          ]
        }
      ],
      "source": [
        "# Peek at a single training record (index 98) to see the raw fields of one sentence pair.\n",
        "print(\"A sample from the STSB dataset's training split:\")\n",
        "print(dataset['train'][98])"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "OjMKsIuxYv6D"
      },
      "source": [
        "### 4. Instantiate the tokenizer and define the dataset class\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "f2Hc2uwabgJa"
      },
      "outputs": [],
      "source": [
        "# Instantiate the BERT tokenizer from the 'bert-base-uncased' checkpoint.\n",
        "# Larger BERT variants can be substituted here; the model class defined later loads the\n",
        "# same checkpoint name, so change both together.\n",
        "tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "uEI1p5-SaM8t"
      },
      "outputs": [],
      "source": [
        "class STSBDataset(torch.utils.data.Dataset):\n",
        "    \"\"\"Sentence pairs of an STSB split with similarity labels scaled by 1/5.\n",
        "\n",
        "    Args:\n",
        "        dataset: Iterable of records providing the keys 'sentence1', 'sentence2'\n",
        "            and 'similarity_score'. It is traversed exactly once, so one-shot\n",
        "            iterables (e.g. generators) are supported as well.\n",
        "        max_length: Token count every sentence is padded or truncated to\n",
        "            (default 128, the value that used to be hard-coded).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, dataset, max_length=128):\n",
        "\n",
        "        self.max_length = max_length\n",
        "        self.first_sentences = []\n",
        "        self.second_sentences = []\n",
        "        self.normalized_similarity_scores = []\n",
        "\n",
        "        # Collect all three columns in a single pass over the records.\n",
        "        for record in dataset:\n",
        "            self.first_sentences.append(record['sentence1'])\n",
        "            self.second_sentences.append(record['sentence2'])\n",
        "            # Scores are divided by 5.0 (presumably mapping the 0-5 STSB scale onto\n",
        "            # [0, 1] so they are comparable to a cosine similarity; confirm the scale).\n",
        "            self.normalized_similarity_scores.append(record['similarity_score'] / 5.0)\n",
        "\n",
        "        # Each item is a two-element list [sentence1, sentence2], coerced to str.\n",
        "        self.concatenated_sentences = [\n",
        "            [str(x), str(y)] for x, y in zip(self.first_sentences, self.second_sentences)\n",
        "        ]\n",
        "\n",
        "    def __len__(self):\n",
        "        \"\"\"Return the number of sentence pairs.\"\"\"\n",
        "        return len(self.concatenated_sentences)\n",
        "\n",
        "    def get_batch_labels(self, idx):\n",
        "        \"\"\"Return the scaled similarity score of pair `idx` as a tensor.\"\"\"\n",
        "        return torch.tensor(self.normalized_similarity_scores[idx])\n",
        "\n",
        "    def get_batch_texts(self, idx):\n",
        "        \"\"\"Tokenize both sentences of pair `idx` with the notebook-level `tokenizer`.\n",
        "\n",
        "        The two sentences are passed as a list, padded/truncated to\n",
        "        `self.max_length` tokens, and returned as PyTorch tensors.\n",
        "        \"\"\"\n",
        "        return tokenizer(self.concatenated_sentences[idx], padding='max_length', max_length=self.max_length, truncation=True, return_tensors=\"pt\")\n",
        "\n",
        "    def __getitem__(self, idx):\n",
        "        \"\"\"Return `(tokenized_pair, label)` for pair `idx`.\"\"\"\n",
        "        batch_texts = self.get_batch_texts(idx)\n",
        "        batch_y = self.get_batch_labels(idx)\n",
        "\n",
        "        return batch_texts, batch_y\n",
        "\n",
        "\n",
        "def collate_fn(texts):\n",
        "    \"\"\"Split a batched encoding into one feature dict per batch element.\n",
        "\n",
        "    `texts` must provide 'input_ids' and 'attention_mask' entries that can be\n",
        "    iterated along their first axis; the i-th items of both are paired up.\n",
        "    Any other entries of `texts` are ignored.\n",
        "\n",
        "    Returns:\n",
        "        list: Dicts with exactly the keys 'input_ids' and 'attention_mask'.\n",
        "    \"\"\"\n",
        "    features = []\n",
        "    for ids, mask in zip(texts['input_ids'], texts['attention_mask']):\n",
        "        features.append({'input_ids': ids, 'attention_mask': mask})\n",
        "\n",
        "    return features"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "w9ICUkr20JbP"
      },
      "source": [
        "### 5. Define the model class based on BERT"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "EgTYEHC8b7kb"
      },
      "outputs": [],
      "source": [
        "class BertForSTS(torch.nn.Module):\n",
        "    \"\"\"Sentence encoder: a 'bert-base-uncased' transformer followed by a pooling layer.\"\"\"\n",
        "\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "\n",
        "        # Transformer module loaded from the checkpoint, with max_seq_length=128.\n",
        "        self.bert = models.Transformer('bert-base-uncased', max_seq_length=128)\n",
        "\n",
        "        # Pooling layer sized to the transformer's word-embedding dimension.\n",
        "        embedding_dim = self.bert.get_word_embedding_dimension()\n",
        "        self.pooling_layer = models.Pooling(embedding_dim)\n",
        "\n",
        "        # Chain both modules so that a single call goes from tokens to sentence embeddings.\n",
        "        self.sts_bert = SentenceTransformer(modules=[self.bert, self.pooling_layer])\n",
        "\n",
        "    def forward(self, input_data):\n",
        "        \"\"\"Return the 'sentence_embedding' entry the wrapped pipeline produces for `input_data`.\"\"\"\n",
        "        features = self.sts_bert(input_data)\n",
        "        return features['sentence_embedding']"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "yMNCebmb4Hlt"
      },
      "outputs": [],
      "source": [
        "# Instantiate the model and move it to the selected `device` (the GPU when one was found, else the CPU)\n",
        "model = BertForSTS()\n",
        "model.to(device)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "IXqIA_D_2nYC"
      },
      "source": [
        "### 6. Define the Cosine Similarity loss function"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ty7Q630Ob96f"
      },
      "outputs": [],
      "source": [
        "class CosineSimilarityLoss(torch.nn.Module):\n",
        "    \"\"\"Loss between the cosine similarity of paired embeddings and target scores.\n",
        "\n",
        "    Args:\n",
        "        loss_fn: Module comparing the predicted similarities with the labels\n",
        "            (mean squared error by default).\n",
        "        transform_fn: Module applied to the cosine similarities before the loss\n",
        "            (identity by default).\n",
        "    \"\"\"\n",
        "\n",
        "    def __init__(self, loss_fn=torch.nn.MSELoss(), transform_fn=torch.nn.Identity()):\n",
        "        super().__init__()\n",
        "        self.loss_fn = loss_fn\n",
        "        self.transform_fn = transform_fn\n",
        "        self.cos_similarity = torch.nn.CosineSimilarity(dim=1)\n",
        "\n",
        "    def forward(self, inputs, labels):\n",
        "        \"\"\"Compute the loss for a batch of embedding pairs.\n",
        "\n",
        "        Args:\n",
        "            inputs: Iterable with one entry per pair; `entry[0]` and `entry[1]`\n",
        "                are the two embeddings whose similarity is scored.\n",
        "            labels: Tensor holding one target similarity per pair, in any shape\n",
        "                with exactly one value per pair (e.g. (B,) or (B, 1)).\n",
        "\n",
        "        Returns:\n",
        "            The value of `loss_fn(similarities, labels)`.\n",
        "        \"\"\"\n",
        "        emb_1 = torch.stack([inp[0] for inp in inputs])\n",
        "        emb_2 = torch.stack([inp[1] for inp in inputs])\n",
        "        outputs = self.transform_fn(self.cos_similarity(emb_1, emb_2))\n",
        "        # Flatten rather than squeeze(): squeeze() turns a one-element batch into a 0-d\n",
        "        # tensor that no longer matches `outputs` of shape (1,), which makes the loss\n",
        "        # warn about (and rely on broadcasting for) a size mismatch.\n",
        "        return self.loss_fn(outputs, labels.reshape(-1))"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "B688H4qY26ZG"
      },
      "source": [
        "### 7. Prepare the training and validation data split"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "PrQvEJgC4VeB",
        "outputId": "2ce3100a-727a-4909-9481-7d6ff0464c12"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "5,749 training samples\n",
            "1,500 validation samples\n"
          ]
        }
      ],
      "source": [
        "# Wrap the predefined 'train' and 'dev' splits; 'dev' serves as the validation set.\n",
        "train_ds = STSBDataset(dataset['train'])\n",
        "val_ds = STSBDataset(dataset['dev'])\n",
        "\n",
        "# Record the size of each split (no additional random split is performed here).\n",
        "train_size, val_size = len(train_ds), len(val_ds)\n",
        "\n",
        "print('{:>5,} training samples'.format(train_size))\n",
        "print('{:>5,} validation samples'.format(val_size))"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "eUPorlzExygm"
      },
      "outputs": [],
      "source": [
        "batch_size = 8\n",
        "\n",
        "# Training batches: samples are drawn in shuffled order.\n",
        "train_dataloader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=4)\n",
        "\n",
        "# Validation batches: same batch size, dataset order kept (no shuffling).\n",
        "validation_dataloader = DataLoader(val_ds, batch_size=batch_size, num_workers=4)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "5avkJtGn2-al"
      },
      "source": [
        "### 8. Define the Optimizer and Scheduler"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "lB_HcVbl3EZw"
      },
      "outputs": [],
      "source": [
        "# AdamW over every model parameter, with a learning rate of 1e-6.\n",
        "optimizer = AdamW(model.parameters(), lr=1e-6)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "RVT3cA_-3NPP"
      },
      "outputs": [],
      "source": [
        "epochs = 8\n",
        "\n",
        "# The scheduler is stepped once per training batch, so the schedule spans\n",
        "# [number of batches] x [number of epochs] steps.\n",
        "total_steps = epochs * len(train_dataloader)\n",
        "\n",
        "# Linear schedule over all training steps, with no warmup steps.\n",
        "scheduler = get_linear_schedule_with_warmup(\n",
        "    optimizer,\n",
        "    num_warmup_steps=0,\n",
        "    num_training_steps=total_steps,\n",
        ")"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "zyIxF_7J3ep5"
      },
      "source": [
        "### 9. Define a helper function for formatting the elapsed training time as `hh:mm:ss`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JH7_0ASp3oDW"
      },
      "outputs": [],
      "source": [
        "def format_time(elapsed):\n",
        "    \"\"\"Format a duration given in seconds as an ``h:mm:ss`` string.\n",
        "\n",
        "    The value is rounded to a whole number of seconds before formatting.\n",
        "    \"\"\"\n",
        "    whole_seconds = int(round(elapsed))\n",
        "    duration = datetime.timedelta(seconds=whole_seconds)\n",
        "    return str(duration)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "jJFhpUJp92Qe"
      },
      "source": [
        "### 10. Define the training function, and start the training loop"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "vdeUXU915NE5"
      },
      "outputs": [],
      "source": [
        "def _embed_pairs(batch):\n",
        "    \"\"\"Move a tokenized batch to `device` and run `model` on each of its elements.\"\"\"\n",
        "    batch['input_ids'] = batch['input_ids'].to(device)\n",
        "    batch['attention_mask'] = batch['attention_mask'].to(device)\n",
        "\n",
        "    # One forward pass per batch element produced by `collate_fn`.\n",
        "    return [model(feature) for feature in collate_fn(batch)]\n",
        "\n",
        "\n",
        "def train():\n",
        "    \"\"\"Fine-tune the notebook-level `model` and validate it after every epoch.\n",
        "\n",
        "    Relies on names defined in earlier cells: `model`, `optimizer`, `scheduler`,\n",
        "    `epochs`, `device`, `train_dataloader` and `validation_dataloader`.\n",
        "\n",
        "    Returns:\n",
        "        tuple: `(model, training_stats)`, where `training_stats` holds one dict per\n",
        "        epoch with the keys 'epoch', 'Training Loss', 'Valid. Loss',\n",
        "        'Training Time' and 'Validation Time'.\n",
        "    \"\"\"\n",
        "    seed_val = 42\n",
        "\n",
        "    criterion = CosineSimilarityLoss()\n",
        "    criterion = criterion.to(device)\n",
        "\n",
        "    # Seed every random number generator the notebook imports.\n",
        "    random.seed(seed_val)\n",
        "    np.random.seed(seed_val)\n",
        "    torch.manual_seed(seed_val)\n",
        "\n",
        "    # Per-epoch training loss, validation loss and timings.\n",
        "    training_stats = []\n",
        "    total_t0 = time.time()\n",
        "\n",
        "    for epoch_i in range(epochs):\n",
        "\n",
        "        # ========================================\n",
        "        #               Training\n",
        "        # ========================================\n",
        "\n",
        "        print(\"\")\n",
        "        print('======== Epoch {:} / {:} ========'.format(epoch_i + 1, epochs))\n",
        "        print('Training...')\n",
        "\n",
        "        t0 = time.time()\n",
        "        total_train_loss = 0\n",
        "\n",
        "        model.train()\n",
        "\n",
        "        # For each batch of training data...\n",
        "        for train_data, train_label in tqdm(train_dataloader):\n",
        "\n",
        "            model.zero_grad()\n",
        "\n",
        "            output = _embed_pairs(train_data)\n",
        "\n",
        "            loss = criterion(output, train_label.to(device))\n",
        "            total_train_loss += loss.item()\n",
        "\n",
        "            loss.backward()\n",
        "            # Clip the gradient norm to 1.0 before the parameter update.\n",
        "            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)\n",
        "            optimizer.step()\n",
        "            # The learning-rate schedule advances once per batch.\n",
        "            scheduler.step()\n",
        "\n",
        "        # Calculate the average loss over all of the batches.\n",
        "        avg_train_loss = total_train_loss / len(train_dataloader)\n",
        "\n",
        "        # Measure how long this epoch took.\n",
        "        training_time = format_time(time.time() - t0)\n",
        "\n",
        "        print(\"\")\n",
        "        print(\"  Average training loss: {0:.5f}\".format(avg_train_loss))\n",
        "        print(\"  Training epoch took: {:}\".format(training_time))\n",
        "\n",
        "        # ========================================\n",
        "        #               Validation\n",
        "        # ========================================\n",
        "\n",
        "        print(\"\")\n",
        "        print(\"Running Validation...\")\n",
        "\n",
        "        t0 = time.time()\n",
        "        total_eval_loss = 0\n",
        "\n",
        "        model.eval()\n",
        "\n",
        "        # Evaluate data for one epoch\n",
        "        for val_data, val_label in tqdm(validation_dataloader):\n",
        "\n",
        "            # Neither the forward passes nor the loss need gradients here.\n",
        "            with torch.no_grad():\n",
        "                output = _embed_pairs(val_data)\n",
        "                loss = criterion(output, val_label.to(device))\n",
        "\n",
        "            total_eval_loss += loss.item()\n",
        "\n",
        "        # Calculate the average loss over all of the batches.\n",
        "        avg_val_loss = total_eval_loss / len(validation_dataloader)\n",
        "\n",
        "        # Measure how long the validation run took.\n",
        "        validation_time = format_time(time.time() - t0)\n",
        "\n",
        "        print(\"  Validation Loss: {0:.5f}\".format(avg_val_loss))\n",
        "        print(\"  Validation took: {:}\".format(validation_time))\n",
        "\n",
        "        # Record all statistics from this epoch.\n",
        "        training_stats.append(\n",
        "            {\n",
        "                'epoch': epoch_i + 1,\n",
        "                'Training Loss': avg_train_loss,\n",
        "                'Valid. Loss': avg_val_loss,\n",
        "                'Training Time': training_time,\n",
        "                'Validation Time': validation_time\n",
        "            }\n",
        "        )\n",
        "\n",
        "    print(\"\")\n",
        "    print(\"Training complete!\")\n",
        "\n",
        "    print(\"Total training took {:} (h:mm:ss)\".format(format_time(time.time()-total_t0)))\n",
        "\n",
        "    return model, training_stats"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "CoWW_TnZgSRf"
      },
      "outputs": [],
      "source": [
        "# Launch the training; train() returns the model together with the per-epoch statistics\n",
        "model, training_stats = train()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 331
        },
        "id": "nEgMWBU7fzXh",
        "outputId": "2adcb8b2-7fb3-422e-d08e-cf701c0240cf"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "\n",
              "  <div id=\"df-471c3966-19db-41fd-9c49-961803816527\">\n",
              "    <div class=\"colab-df-container\">\n",
              "      <div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>Training Loss</th>\n",
              "      <th>Valid. Loss</th>\n",
              "      <th>Training Time</th>\n",
              "      <th>Validation Time</th>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>epoch</th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "      <th></th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>1</th>\n",
              "      <td>0.032639</td>\n",
              "      <td>0.037972</td>\n",
              "      <td>0:05:29</td>\n",
              "      <td>0:00:28</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>0.030737</td>\n",
              "      <td>0.035472</td>\n",
              "      <td>0:05:28</td>\n",
              "      <td>0:00:28</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>3</th>\n",
              "      <td>0.027920</td>\n",
              "      <td>0.033640</td>\n",
              "      <td>0:05:29</td>\n",
              "      <td>0:00:28</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>0.025090</td>\n",
              "      <td>0.032185</td>\n",
              "      <td>0:05:29</td>\n",
              "      <td>0:00:28</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>5</th>\n",
              "      <td>0.023217</td>\n",
              "      <td>0.030802</td>\n",
              "      <td>0:05:27</td>\n",
              "      <td>0:00:28</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>6</th>\n",
              "      <td>0.021199</td>\n",
              "      <td>0.030223</td>\n",
              "      <td>0:05:29</td>\n",
              "      <td>0:00:28</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>0.019567</td>\n",
              "      <td>0.029389</td>\n",
              "      <td>0:05:28</td>\n",
              "      <td>0:00:28</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>8</th>\n",
              "      <td>0.017866</td>\n",
              "      <td>0.028664</td>\n",
              "      <td>0:05:29</td>\n",
              "      <td>0:00:28</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>\n",
              "      <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-471c3966-19db-41fd-9c49-961803816527')\"\n",
              "              title=\"Convert this dataframe to an interactive table.\"\n",
              "              style=\"display:none;\">\n",
              "        \n",
              "  <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
              "       width=\"24px\">\n",
              "    <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
              "    <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
              "  </svg>\n",
              "      </button>\n",
              "      \n",
              "  <style>\n",
              "    .colab-df-container {\n",
              "      display:flex;\n",
              "      flex-wrap:wrap;\n",
              "      gap: 12px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert {\n",
              "      background-color: #E8F0FE;\n",
              "      border: none;\n",
              "      border-radius: 50%;\n",
              "      cursor: pointer;\n",
              "      display: none;\n",
              "      fill: #1967D2;\n",
              "      height: 32px;\n",
              "      padding: 0 0 0 0;\n",
              "      width: 32px;\n",
              "    }\n",
              "\n",
              "    .colab-df-convert:hover {\n",
              "      background-color: #E2EBFA;\n",
              "      box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
              "      fill: #174EA6;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert {\n",
              "      background-color: #3B4455;\n",
              "      fill: #D2E3FC;\n",
              "    }\n",
              "\n",
              "    [theme=dark] .colab-df-convert:hover {\n",
              "      background-color: #434B5C;\n",
              "      box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
              "      filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
              "      fill: #FFFFFF;\n",
              "    }\n",
              "  </style>\n",
              "\n",
              "      <script>\n",
              "        const buttonEl =\n",
              "          document.querySelector('#df-471c3966-19db-41fd-9c49-961803816527 button.colab-df-convert');\n",
              "        buttonEl.style.display =\n",
              "          google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
              "\n",
              "        async function convertToInteractive(key) {\n",
              "          const element = document.querySelector('#df-471c3966-19db-41fd-9c49-961803816527');\n",
              "          const dataTable =\n",
              "            await google.colab.kernel.invokeFunction('convertToInteractive',\n",
              "                                                     [key], {});\n",
              "          if (!dataTable) return;\n",
              "\n",
              "          const docLinkHtml = 'Like what you see? Visit the ' +\n",
              "            '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
              "            + ' to learn more about interactive tables.';\n",
              "          element.innerHTML = '';\n",
              "          dataTable['output_type'] = 'display_data';\n",
              "          await google.colab.output.renderOutput(dataTable, element);\n",
              "          const docLink = document.createElement('div');\n",
              "          docLink.innerHTML = docLinkHtml;\n",
              "          element.appendChild(docLink);\n",
              "        }\n",
              "      </script>\n",
              "    </div>\n",
              "  </div>\n",
              "  "
            ],
            "text/plain": [
              "       Training Loss  Valid. Loss Training Time Validation Time\n",
              "epoch                                                          \n",
              "1           0.032639     0.037972       0:05:29         0:00:28\n",
              "2           0.030737     0.035472       0:05:28         0:00:28\n",
              "3           0.027920     0.033640       0:05:29         0:00:28\n",
              "4           0.025090     0.032185       0:05:29         0:00:28\n",
              "5           0.023217     0.030802       0:05:27         0:00:28\n",
              "6           0.021199     0.030223       0:05:29         0:00:28\n",
              "7           0.019567     0.029389       0:05:28         0:00:28\n",
              "8           0.017866     0.028664       0:05:29         0:00:28"
            ]
          },
          "execution_count": 31,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# Tabulate the per-epoch training statistics, using the 'epoch' column as the row index\n",
        "df_stats = pd.DataFrame(data=training_stats).set_index('epoch')\n",
        "\n",
        "# Display the table\n",
        "df_stats"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "X7ahIyP4zsXp",
        "outputId": "ddd2fa70-5a34-4db3-b6ee-b784d59bfb2d"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:datasets.builder:Found cached dataset stsb_multi_mt (/root/.cache/huggingface/datasets/stsb_multi_mt/en/1.0.0/a5d260e4b7aa82d1ab7379523a005a366d9b124c76a5a5cf0c4c5365458b0ba9)\n"
          ]
        }
      ],
      "source": [
        "# Load only the test split of the English STSB data\n",
        "test_dataset = load_dataset(\"stsb_multi_mt\", name=\"en\", split=\"test\")\n",
        "\n",
        "# Collect both sentence columns, then pair them up as [sentence1, sentence2] lists of strings\n",
        "first_sent = [row['sentence1'] for row in test_dataset]\n",
        "second_sent = [row['sentence2'] for row in test_dataset]\n",
        "full_text = [list(map(str, pair)) for pair in zip(first_sent, second_sent)]"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wD7oPneMkUhe"
      },
      "outputs": [],
      "source": [
        "model.eval()\n",
        "\n",
        "def predict_similarity(sentence_pair):\n",
        "    \"\"\"Predict how similar two sentences are.\n",
        "\n",
        "    Args:\n",
        "        sentence_pair: Sequence holding the two sentences to compare.\n",
        "\n",
        "    Returns:\n",
        "        float: Cosine similarity between the embeddings `model` produces for the\n",
        "        first and the second sentence.\n",
        "    \"\"\"\n",
        "    encoded = tokenizer(sentence_pair, padding='max_length', max_length=128, truncation=True, return_tensors=\"pt\").to(device)\n",
        "\n",
        "    # Forward only the fields the model was trained with (see `collate_fn`); this leaves\n",
        "    # out 'token_type_ids' without failing when the tokenizer did not return it.\n",
        "    features = {\n",
        "        'input_ids': encoded['input_ids'],\n",
        "        'attention_mask': encoded['attention_mask'],\n",
        "    }\n",
        "\n",
        "    # Inference only: no gradients are needed.\n",
        "    with torch.no_grad():\n",
        "        output = model(features)\n",
        "\n",
        "    return torch.nn.functional.cosine_similarity(output[0], output[1], dim=0).item()"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "e-lGkcofz6hS",
        "outputId": "dd20141d-0496-4426-a97d-0c020612106d"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Sentence 1: A cat is walking around a house.\n",
            "Sentence 2: A woman is peeling potato.\n",
            "Predicted similarity score: 0.01\n"
          ]
        }
      ],
      "source": [
        "example_1 = full_text[100]\n",
        "\n",
        "# Show both sentences of the pair, then the prediction rounded to two decimals\n",
        "print(f\"Sentence 1: {example_1[0]}\", f\"Sentence 2: {example_1[1]}\", sep=\"\\n\")\n",
        "print(f\"Predicted similarity score: {round(predict_similarity(example_1), 2)}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "ViwfU0M2DOgh",
        "outputId": "e677ea0a-4ac8-4d38-e0d8-06baa71bbcb9"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Sentence 1: Two men are playing football.\n",
            "Sentence 2: Two men are practicing football.\n",
            "Predicted similarity score: 0.84\n"
          ]
        }
      ],
      "source": [
        "example_2 = full_text[130]\n",
        "\n",
        "# Show both sentences of the pair, then the prediction rounded to two decimals\n",
        "print(f\"Sentence 1: {example_2[0]}\", f\"Sentence 2: {example_2[1]}\", sep=\"\\n\")\n",
        "print(f\"Predicted similarity score: {round(predict_similarity(example_2), 2)}\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "sGn-H7ARDnBG",
        "outputId": "ea5b057d-40f4-4c9c-896e-ebe6223a6635"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Sentence 1: It varies by the situation.\n",
            "Sentence 2: This varies by institution.\n",
            "Predicted similarity score: 0.6\n"
          ]
        }
      ],
      "source": [
        "example_3 = full_text[812]\n",
        "\n",
        "# Show both sentences of the pair, then the prediction rounded to two decimals\n",
        "print(f\"Sentence 1: {example_3[0]}\", f\"Sentence 2: {example_3[1]}\", sep=\"\\n\")\n",
        "print(f\"Predicted similarity score: {round(predict_similarity(example_3), 2)}\")"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "_XovRH0VkXXs"
      },
      "source": [
        "### Last but not least, save your model!"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "Om3wskAQkaJP"
      },
      "outputs": [],
      "source": [
        "# Placeholder: replace with the file path the weights should be written to\n",
        "PATH = 'your/path/here'\n",
        "# Save the model's state_dict rather than the module object itself\n",
        "torch.save(model.state_dict(), PATH)"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "wCe1I2soj-Kj"
      },
      "outputs": [],
      "source": [
        "# In order to load the model:\n",
        "# 1. create a fresh instance of the model's class,\n",
        "# 2. load the saved state_dict from PATH, remapping its tensors onto `device` so that\n",
        "#    weights saved on a GPU can also be restored on a CPU-only machine,\n",
        "# 3. move the model to `device`, which is where `predict_similarity` places its inputs,\n",
        "# 4. set the model to the evaluation state using .eval()\n",
        "\n",
        "model = BertForSTS()\n",
        "model.load_state_dict(torch.load(PATH, map_location=device))\n",
        "model.to(device)\n",
        "model.eval()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "gpuClass": "standard",
    "kernelspec": {
      "display_name": "Python 3",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.9.12"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
